import os
from tqdm import tqdm
import nltk
import numpy as np
import random
from collections import Counter
import re
from bpemb import BPEmb
from build_sppmi import build_sppmi_matrix

# hyper-parameters for data processing
min_review_len = 20  # a review is kept only if it has MORE than this many in-vocabulary tokens
min_freq = 15  # minimum total count for a token to enter the shared vocabulary
val_ratio = 0.05  # fraction of each split domain's reviews saved as validation data
# for SG
# Keep the subword vocabulary size and the BPEmb model's vocabulary size in sync
# by deriving the latter from the former.
piece_vocab_size = 10000
bpemb_en = BPEmb(lang="en", dim=50, vs=piece_vocab_size)
# for sppmi (both values are passed straight to build_sppmi_matrix)
window_size = 7
shift = 10

# input corpora are read from ../Dataset (relative to the current working directory)
rootdir = '..'
datadir = os.path.join(rootdir,'Dataset')

# output roots for the training part and the validation part
savedir = os.path.join('data_correct', 'Ama')
val_savedir = os.path.join('data_correct', 'val_Ama')

# sub-directories (under savedir / val_savedir) holding each kind of output file
dict_savedir = 'Dict'
cor_savedir = 'Cor'
label_savedir = 'Label'
Dt_savedir = 'Dt'
label_dict_savedir = 'LabelDict'
St_savedir = 'subword'
sppmi_savedir = 'sppmi'


# Normalized domain name (whitespace -> '_', '&' -> 'and') -> numeric id used
# as the suffix of every output file name for that domain.
domain_order = {
	'Safety_Signs_and_Signals':0,
	'Pet_Behavior_Center':1,
	'SIM_Cards_and_Prepaid_Minutes':2,
	'Horses':3,
	'Video':4,
	'Keyboards_and_MIDI':5,
	'Characters_and_Series':6,
	'Cult_Movies':7,
	'Science_Education':8
}



# build_vocab
def build_vocab(paths, stopwords_path='stopwords.dict'):
	"""Build the shared vocabulary over a set of review corpora.

	Every non-empty line of every file in ``paths`` is tokenized with
	``nltk.word_tokenize``. A token is kept when it occurs at least
	``min_freq`` times over all files, contains no '_' and is not listed in
	the stopword file.

	Args:
		paths: iterable of UTF-8 text files, one review per line.
		stopwords_path: UTF-8 file with one stopword per line; defaults to
			'stopwords.dict' (resolved against the current working directory).

	Returns:
		(vocab, word2id, id2word): token -> count, token -> id, and the
		id -> token list. Ids follow the tokens' first-occurrence order.
	"""
	# Load stopwords first so a missing file fails before the long counting
	# pass; a set gives O(1) membership tests in the filter below.
	with open(stopwords_path,'r',encoding='utf8') as fr:
		stopwords = {line.strip() for line in fr}

	# calculate word frequency
	counter = Counter()
	for path in paths:
		print('processing',path,'...')
		with open(path,'r',encoding='utf8') as fr:
			for line in tqdm(fr,desc='reading reviews for building vocab ...'):
				line = line.strip()
				if line:
					counter.update(nltk.word_tokenize(line))

	vocab = {}
	word2id = {}
	id2word = []
	for word, cnt in tqdm(counter.items(),desc='building vocab ...'):
		# filter low frequency words / _ / stopwords
		if cnt >= min_freq and '_' not in word and word not in stopwords:
			vocab[word] = cnt
			word2id[word] = len(id2word)
			id2word.append(word)

	return vocab, word2id, id2word


# build_vocab with cor
def build_vocab_cor(cor):
	"""Build an unfiltered vocabulary from an in-memory corpus.

	Args:
		cor: iterable of review strings; blank entries are skipped and the
			rest are tokenized with ``nltk.word_tokenize``.

	Returns:
		(vocab, word2id, id2word) as in build_vocab, but every token is kept
		(no frequency, '_' or stopword filtering). Ids follow first occurrence.
	"""
	# calculate word frequency
	counter = Counter()
	for line in cor:
		line = line.strip()
		if line:
			counter.update(nltk.word_tokenize(line))

	# NOTE(review): main_process feeds this the space-joined token lists and
	# save_data then looks up the ORIGINAL tokens in the returned word2id; this
	# relies on word_tokenize reproducing those tokens when re-run on the joined
	# text (otherwise save_data raises KeyError) — confirm for unusual tokens.
	vocab = dict(counter)
	word2id = {}
	id2word = []
	for word in tqdm(vocab,desc='building vocab ...'):
		word2id[word] = len(id2word)
		id2word.append(word)

	return vocab, word2id, id2word


# freeze random seeds
def freeze_seed(seed=12345):
	"""Seed Python's and NumPy's global RNGs and export PYTHONHASHSEED.

	Args:
		seed: value used for every generator; also written (as a string) to
			the PYTHONHASHSEED environment variable.

	Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here
	only affects child processes, not this interpreter's str hashing.
	"""
	seed_text = str(seed)
	os.environ['PYTHONHASHSEED'] = seed_text
	random.seed(seed)
	np.random.seed(seed)


# save data
def _write_lines(path, lines, trailing_newline=False):
	"""Write ``lines`` newline-separated to ``path`` as UTF-8 with '\\n' line endings."""
	with open(path,'w',encoding='utf8',newline='\n') as fw:
		fw.write('\n'.join(lines))
		if trailing_newline:
			fw.write('\n')


def save_data(savedir, cn, cor, label, Dt_t, X, N_t, id2word, word2id, label_dict):
	"""Write every output file of one domain split under ``savedir``.

	``cn`` is normalized (whitespace -> '_', '&' -> 'and') and mapped through
	``domain_order`` to the numeric id used in all file names. Missing output
	sub-directories are created. Text outputs are UTF-8 and newline-separated.

	Args:
		savedir: output root (e.g. savedir or val_savedir).
		cn: raw domain name.
		cor: list of review strings.
		label: list of label strings, parallel to ``cor``.
		Dt_t: list of {token: count} dicts, parallel to ``cor``.
		X: array saved as sppmi/sppmi_<id>.npy.
		N_t: array saved as subword/X_Comp_relation_<id>.npy.
		id2word: id -> token list (written to Dict/Id2word_<id>.txt).
		word2id: token -> id mapping used to build the count matrix.
		label_dict: distinct labels (written to LabelDict/label_dict_<id>.txt).

	Raises:
		KeyError: if the normalized ``cn`` is not in ``domain_order`` or a
			token of ``Dt_t`` is missing from ``word2id``.
	"""
	# save new_dict
	cn = re.sub(r'\s','_',cn)
	cn = re.sub(r'\&','and',cn)

	cn_idx = str(domain_order[cn])

	# Create every target directory up front so a fresh checkout does not
	# crash on the first open()/np.save().
	for sub in (dict_savedir, cor_savedir, label_savedir, label_dict_savedir,
			Dt_savedir, St_savedir, sppmi_savedir):
		os.makedirs(os.path.join(savedir,sub), exist_ok=True)

	print('saving Id2word_'+cn, '...')
	_write_lines(os.path.join(savedir,dict_savedir,'Id2word_'+cn_idx+'.txt'), id2word, trailing_newline=True)

	print('saving cor_'+cn, '...')
	_write_lines(os.path.join(savedir,cor_savedir,'cor_'+cn_idx+'.txt'), cor)

	print('saving label_'+cn, '...')
	_write_lines(os.path.join(savedir,label_savedir,'label_'+cn_idx+'.txt'), label)

	print('saving label_dict_'+cn, '...')
	_write_lines(os.path.join(savedir,label_dict_savedir,'label_dict_'+cn_idx+'.txt'), label_dict)

	# Document x word count matrix: Dt[i, word2id[w]] = count of w in review i.
	Dt = np.zeros((len(cor),len(id2word)),dtype='int32')
	for i, dt in tqdm(enumerate(Dt_t),desc='building Dt ...'):
		for w, cnt in dt.items():
			Dt[i, word2id[w]] = cnt
	print('saving X_'+cn, '...')
	np.save(os.path.join(savedir,Dt_savedir,'X_'+cn_idx+'.npy'), Dt)

	print('saving X_Comp_relation_'+cn+' ...')
	np.save(os.path.join(savedir,St_savedir,'X_Comp_relation_'+cn_idx+'.npy'), N_t)

	print('saving sppmi_'+cn+' ...')
	np.save(os.path.join(savedir,sppmi_savedir,'sppmi_'+cn_idx+'.npy'), X)


def build_St(id2word, encoder=None):
	"""Build the subword-sharing matrix N_t for a vocabulary.

	Each word is split into pieces with ``encoder.encode(word)``. For every
	piece shared by two distinct words i and j,
	``min(occurrences of the piece in word i, occurrences in word j)`` is added
	to both N_t[i, j] and N_t[j, i]. The diagonal stays 1 (N_t starts as the
	identity).

	Args:
		id2word: list of words; a word's position is its id.
		encoder: object with an ``encode(word) -> list of pieces`` method;
			defaults to the module-level ``bpemb_en``.

	Returns:
		int32 ndarray of shape (len(id2word), len(id2word)).
	"""
	# build St using pretrained BPE
	if encoder is None:
		encoder = bpemb_en

	vocab_size = len(id2word)

	# piece -> {word id: occurrences of the piece in that word}. This sparse
	# map replaces a dense (piece_vocab_size x vocab_size) matrix: it keeps
	# only the pieces actually seen and cannot overflow a fixed row count.
	# Word ids are inserted in ascending order.
	piece_words = {}
	for wid, word in tqdm(enumerate(id2word),desc='building Nt ...'):
		for wp in encoder.encode(word):
			counts = piece_words.setdefault(wp, {})
			counts[wid] = counts.get(wid, 0) + 1

	N_t = np.eye(vocab_size, dtype='int32')
	for counts in tqdm(piece_words.values(),desc='building Nt ...'):
		if len(counts) < 2:
			continue
		idx = np.fromiter(counts.keys(), dtype=np.intp, count=len(counts))
		cnt = np.fromiter(counts.values(), dtype='int32', count=len(counts))
		# pairwise min over all words sharing this piece; self-pairs excluded
		shared = np.minimum.outer(cnt, cnt)
		np.fill_diagonal(shared, 0)
		# idx holds unique ids, so fancy-indexed += adds each pair exactly once
		N_t[np.ix_(idx, idx)] += shared

	# check: print up to 7 pieces with the words (and counts) containing them
	for wp in list(piece_words)[:7]:
		out = [id2word[wid]+'('+str(cnt)+')' for wid, cnt in piece_words[wp].items()]
		print(wp,':',' '.join(out))

	return N_t


def _find_domains(data_root):
	"""Return (names, dirs) for domains that have a '<name>_cor.txt' file under data_root."""
	names = []
	dirs = []
	for root, _, files in os.walk(data_root):
		for f in files:
			cn = f.split('_')[0]
			if cn in names:
				continue
			# Skip stray files (e.g. '.DS_Store') whose prefix has no corpus
			# file; they would otherwise make build_vocab open a missing path.
			if not os.path.isfile(os.path.join(root,cn+'_cor.txt')):
				continue
			names.append(cn)
			dirs.append(root)
	return names, dirs


def _write_id2word(path, id2word):
	"""Write one token per line (plus a trailing newline) as UTF-8, creating parent dirs."""
	os.makedirs(os.path.dirname(path), exist_ok=True)
	with open(path,'w',encoding='utf8',newline='\n') as fw:
		fw.write('\n'.join(id2word))
		fw.write('\n')


def _load_domain(cor_path, label_path, total_word2id):
	"""Tokenize one domain, keep in-vocabulary tokens and long-enough reviews.

	Returns (cor, label, Dt_t, word2id, id2word, label_dict).
	"""
	cor = []
	label = []
	Dt_t = []
	word2id = {}
	id2word = []
	label_dict = []
	seen_labels = set()

	# Read as UTF-8, matching build_vocab which reads the same corpus files.
	with open(cor_path,'r',encoding='utf8') as fr_cor, open(label_path,'r',encoding='utf8') as fr_label:
		for c, lab in tqdm(zip(fr_cor,fr_label), desc='filtering low frequency words ...'):
			words = [word for word in nltk.word_tokenize(c.strip()) if word in total_word2id]
			# keep only reviews with MORE than `min_review_len` in-vocabulary tokens
			if len(words) <= min_review_len:
				continue
			lab = lab.strip()
			for word in words:
				if word not in word2id:
					word2id[word] = len(id2word)
					id2word.append(word)

			cor.append(' '.join(words))
			label.append(lab)
			if lab not in seen_labels:
				seen_labels.add(lab)
				label_dict.append(lab)

			Dt_t.append(dict(Counter(words)))

	return cor, label, Dt_t, word2id, id2word, label_dict


def _build_and_save(target_dir, cn, cor, label, Dt_t, word2id, id2word, label_dict):
	"""Build the SPPMI and subword matrices for one corpus part and save everything."""
	X = build_sppmi_matrix(cor, word2id, len(id2word), window_size, shift)
	St = build_St(id2word)
	save_data(target_dir, cn, cor, label, Dt_t, X, St, id2word, word2id, label_dict)


def main_process():
	"""Run the preprocessing pipeline over every domain found under ``datadir``.

	1. Build the shared vocabulary over all domain corpora (build_vocab) and
	   write it to Dict/Id2word_total.txt under both ``savedir`` and
	   ``val_savedir``.
	2. For each domain keep reviews with more than ``min_review_len``
	   in-vocabulary tokens.
	3. Every domain except 'Science Education' is shuffled with NumPy's global
	   RNG and split: the first ``int(val_ratio * n)`` reviews go to
	   ``val_savedir``, the rest to ``savedir``, each with its own vocabulary.
	   'Science Education' is saved unsplit under ``savedir``.
	"""
	# build total dict
	chunk_name, chunk_dirs = _find_domains(datadir)
	chunk_paths = [os.path.join(d,cn+'_cor.txt') for cn, d in zip(chunk_name, chunk_dirs)]

	_, total_word2id, total_id2word = build_vocab(chunk_paths)
	print('total dict size =',len(total_id2word))
	print('saving Id2word_total ...')
	for out_dir in (savedir, val_savedir):
		_write_id2word(os.path.join(out_dir,dict_savedir,'Id2word_total.txt'), total_id2word)

	# process each domain
	for cn, cn_dir in zip(chunk_name, chunk_dirs):
		print('start to process',cn,'...')
		cor, label, Dt_t, word2id, id2word, label_dict = _load_domain(
			os.path.join(cn_dir,cn+'_cor.txt'),
			os.path.join(cn_dir,cn+'_label.txt'),
			total_word2id)

		print('vocab size =',len(id2word))
		cor_size = len(cor)
		print('cor size =',cor_size)
		print('label size =',len(label_dict))

		if cn == 'Science Education':
			# not split: saved whole, keeping the vocabulary order from loading
			_build_and_save(savedir, cn, cor, label, Dt_t, word2id, id2word, label_dict)
			continue

		# segment train/val dataset
		idx_list = np.arange(cor_size)
		np.random.shuffle(idx_list)
		val_size = int(val_ratio*cor_size)
		splits = ((val_savedir, idx_list[:val_size]), (savedir, idx_list[val_size:]))
		for target_dir, part in splits:
			part_cor = [cor[i] for i in part]
			part_label = [label[i] for i in part]
			part_Dt = [Dt_t[i] for i in part]
			_, part_word2id, part_id2word = build_vocab_cor(part_cor)
			_build_and_save(target_dir, cn, part_cor, part_label, part_Dt, part_word2id, part_id2word, label_dict)


def main():
	"""Script entry point: seed the RNGs, then run the preprocessing pipeline."""
	# Seeding first makes the NumPy shuffle used for train/val splits repeatable.
	freeze_seed()
	main_process()


if __name__ == '__main__':
	main()
